import ast
import copy
import logging
import math

import numpy as np
from scipy.spatial import KDTree
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
# import matplotlib.pyplot as plt
# import pandas as pd

from utils.format_util import cast_float
from utils.format_util import dup_name_handler

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")


def best_kmeans(data):
    """Pick the number of clusters k in [2, max_k] with the best silhouette score.

    Runs KMeans for each candidate k and keeps the k whose clustering yields
    the highest absolute silhouette score.  If KMeans or the silhouette
    computation raises (e.g. fewer samples than clusters), the search stops
    and the best k found so far is returned.

    :param data: 2-D array-like of samples (rows) by features (columns)
    :return: chosen number of clusters (int, >= 2)
    """
    max_k = 5
    best_silhouette = -1.0
    best_k = 2
    # Off-by-one fix: range(2, max_k) never tried k == max_k despite the name.
    for k in range(2, max_k + 1):
        try:
            # fit_predict fits and labels in one pass (was fit + predict).
            pred = KMeans(n_clusters=k).fit_predict(data)
            # NOTE(review): abs() also rewards strongly *negative* silhouettes;
            # kept as-is for backward compatibility — confirm intent.
            silhouette = abs(silhouette_score(data, pred))
            if silhouette > best_silhouette:
                best_silhouette = silhouette
                best_k = k
        except Exception as e:
            logging.info("catch exception: {}, use {} as best k".format(e, best_k))
            break
    return best_k


def kmeans_outliers(data, action):
    """Detect outliers as points unusually far from their KMeans cluster center.

    Distances from each sample to its assigned center are sorted; among the
    gaps between consecutive sorted distances in the upper half, the first
    gap (scanning from the largest) that exceeds the next-largest gap by a
    factor of ``outlier_thres`` marks the inlier/outlier boundary.

    :param data: 2-D numpy array of samples (rows) by features (columns)
    :param action: when 'replace_nearest', also compute the nearest inlier
        for every outlier
    :return: (outlier_index, nearest_index) — row indices into ``data``;
        both empty when no qualifying gap is found
    """
    outlier_index = []
    nearest_index = []

    outlier_thres = 1.5
    k = best_kmeans(data)
    model = KMeans(n_clusters=k).fit(data)
    labels = model.labels_
    # Center assigned to each sample (vectorized; was a Python loop over labels).
    centers = model.cluster_centers_[labels]
    dists = np.linalg.norm(data - centers, axis=1)

    sort_dist_index = np.argsort(dists)
    sort_dists = dists[sort_dist_index]
    # BUG FIX: must be the *position* of the median within the sorted order,
    # not the original data index of the median element
    # (was: sort_dist_index[len(sort_dist_index) // 2], an arbitrary row id
    # then misused as a slice position below and at the boundary mapping).
    median_index = len(sort_dists) // 2
    # Gaps between consecutive sorted distances above the median.
    diff_dists = sort_dists[median_index + 1:] - sort_dists[median_index:-1]
    sort_diff_dist_index = np.argsort(diff_dists)
    sort_diff_dists = diff_dists[sort_diff_dist_index]
    # plt.hist(dists, bins=20)
    # plt.show()

    # Scan gaps from largest to smallest; stop at the first that dominates
    # the next-largest gap by more than outlier_thres.
    for i in range(len(diff_dists) - 1, 0, -1):
        if sort_diff_dists[i - 1] > 0 and sort_diff_dists[i] / sort_diff_dists[i - 1] > outlier_thres:
            # Translate gap position back to a position in the sorted distances.
            max_diff_dist_index = sort_diff_dist_index[i] + median_index + 1
            outlier_index = sort_dist_index[max_diff_dist_index:]
            inlier_index = sort_dist_index[:max_diff_dist_index]
            if action == 'replace_nearest':
                outliers = data[outlier_index]
                inliers = data[inlier_index]
                tree = KDTree(inliers)
                # Nearest inlier for each outlier, mapped back to row indices.
                nearest_index = tree.query(outliers)[1]
                nearest_index = inlier_index[nearest_index]
            break
    return outlier_index, nearest_index


def run(df, params):
    """Detect outliers in the selected columns of ``df`` and apply ``action``.

    :param df: pandas DataFrame holding the data
    :param params: dict with
        - "cols": string holding a Python list literal of column names,
          e.g. "['a', 'b']"
        - "append": "0"/"1" — when truthy (and action != "no"), results are
          written to new columns inserted after their source columns
        - "action": 'no' | 'delete' | 'replace_null' | 'replace_nearest'
    :return: (df, results) — results carries up to max_vis_pt visualization
        points (2-D, PCA-projected when >2 columns) plus either a per-column
        outlier highlight or a no-outlier message
    """
    max_vis_pt = 50

    outliers = {}
    results = {}
    results["data_vis"] = {}

    # Security fix: params["cols"] is a list literal, so literal_eval suffices;
    # eval() on an externally supplied string could execute arbitrary code.
    cols = ast.literal_eval(params.get("cols"))
    is_append = int(params.get("append"))
    action = params.get("action")
    all_cols = df.columns.tolist()

    data = df[cols].values
    # Rows with no NaN in any selected column.
    valid_index = np.where(np.squeeze(np.isnan(data).any(axis=1) == 0))[0]
    valid_data = data[valid_index]
    outlier_index, nearest_index = kmeans_outliers(valid_data, action)
    # Map positions within valid_data back to row positions in data.
    outlier_index = valid_index[outlier_index]
    nearest_index = valid_index[nearest_index]

    if len(outlier_index) > 0:
        for i, col in enumerate(cols):
            outlier = data[outlier_index, i].tolist()
            outliers[col] = list(set(outlier))

        # Project to 2-D for visualization when more than two feature columns;
        # rows that were dropped for NaN stay NaN in the projection.
        if data.shape[1] > 2:
            mlmodel = PCA(2).fit(valid_data)
            pca_inliers = mlmodel.transform(valid_data)
            inliers = np.full([len(data), 2], np.nan)
            inliers[valid_index] = pca_inliers
        else:
            inliers = copy.deepcopy(data)

        if data.shape[1] > 1:
            nearest_inliers = inliers[nearest_index].tolist()
            inlier_index = list(set(range(len(data))) - set(outlier_index) - set(nearest_index))
            results["data_vis"]['vis_inliers'] = inliers[inlier_index][:max_vis_pt].tolist()
            results["data_vis"]['vis_nearest_inliers'] = nearest_inliers
            results["data_vis"]['vis_outliers'] = inliers[outlier_index].tolist()

        if action == 'replace_nearest':
            data[outlier_index] = data[nearest_index].tolist()
        if action == 'replace_null':
            # Cast to float so NaN can be assigned into an integer array.
            none_data = np.full([len(outlier_index), len(cols)], np.nan)
            data = data.astype(float)
            data[outlier_index] = none_data

    if action != "no" and is_append:
        for i, col in enumerate(cols):
            new_col = '_'.join([col, '近邻异常检测'])
            new_col = dup_name_handler(new_col, all_cols)

            # Insert the new column directly after its source column.
            index = all_cols.index(col) + 1
            all_cols.insert(index, new_col)
            df = df.reindex(columns=all_cols)
            df[new_col] = data[:, i]
    else:
        df[cols] = data
    if action == 'delete':
        df.drop(outlier_index, inplace=True)

    results = cast_float(results)
    if len(outlier_index) > 0:
        results['highlight'] = outliers
    else:
        results["message"] = "无异常值"
    return df, results


# if __name__ == '__main__':
#     params = {}
#     params['source'] = 'dataset._norm_'
#     params['cols'] = "['a', 'b']"
#     params['append'] = '1'
#     params['target'] = 'dataset._norm_result_'
#     params['action'] = 'replace_nearest'
#     cols = eval(params['cols'])
#
#     df = pd.read_csv('/home/igor/zjlab/data/source_data/test_norm.csv')
#     df, results = run(df, params)
#     print(results)
